//
// Created by enemy1205 on 2021/8/31.
//
#include "Night_Mode.h"

#include <algorithm>
#include <cmath>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
using namespace cv;
/**
 * Compute the dark and bright illumination channels of a 3-channel image.
 *
 * The image is converted to 32-bit float, divided by 255 and placed in the
 * middle of a zero-filled canvas with a border of int(w/2) pixels. For every
 * pixel the minimum (dark channel) and maximum (bright channel) over all three
 * colour channels inside the window starting at that pixel in the padded
 * canvas are recorded.
 *
 * The input may be 8-bit or floating point: it is converted with convertTo
 * before being read, so a CV_32FC3 image is no longer misread as bytes.
 *
 * NOTE(review): the border is zero-filled, so windows that overlap the border
 * always report a dark-channel value of 0. Replicating the edge pixels may be
 * what was intended -- confirm before changing.
 *
 * @param I 3-channel image (any depth convertTo can handle).
 * @param w window side length in pixels; the window spans [k, int(k + w)).
 * @return pair (dark channel, bright channel), both CV_32FC1 and the same
 *         size as I.
 */
std::pair<Mat, Mat> get_illumination_channel(Mat I, float w) {
    const int N = I.size[0];  // rows
    const int M = I.size[1];  // columns
    Mat darkch = Mat::zeros(Size(M, N), CV_32FC1);
    Mat brightch = Mat::zeros(Size(M, N), CV_32FC1);

    const int padding = static_cast<int>(w / 2);
    Mat padded = Mat::zeros(Size(M + 2 * padding, N + 2 * padding), CV_32FC3);

    // Normalise the depth first so that both 8-bit and float inputs are read
    // correctly; the values themselves are unchanged by this conversion.
    Mat asFloat;
    I.convertTo(asFloat, CV_32F);

    // Copy the image, scaled by 1/255, into the centre of the padded canvas.
    for (int r = 0; r < N; ++r) {
        const Vec3f* src = asFloat.ptr<Vec3f>(r);
        Vec3f* dst = padded.ptr<Vec3f>(r + padding) + padding;
        for (int c = 0; c < M; ++c) {
            dst[c] = Vec3f(src[c][0] / 255.0f, src[c][1] / 255.0f, src[c][2] / 255.0f);
        }
    }

    for (int r = 0; r < N; ++r) {
        const int rowEnd = static_cast<int>(r + w);
        float* darkRow = darkch.ptr<float>(r);
        float* brightRow = brightch.ptr<float>(r);

        for (int c = 0; c < M; ++c) {
            const int colEnd = static_cast<int>(c + w);

            // With no location outputs requested, minMaxLoc scans every
            // channel of the window, giving the min/max over all colours.
            double minVal = 0.0;
            double maxVal = 0.0;
            minMaxLoc(padded(Range(r, rowEnd), Range(c, colEnd)), &minVal, &maxVal);

            darkRow[c] = static_cast<float>(minVal);
            brightRow[c] = static_cast<float>(maxVal);
        }
    }

    return std::make_pair(darkch, brightch);
}

/**
 * Estimate the atmospheric light from the brightest pixels of the image.
 *
 * Pixels are ranked by their bright-channel value (highest first; ties are
 * broken by the column-major index col * rows + row). The colour values of
 * the top int(rows * cols * p) pixels are averaged and divided by 255.
 *
 * The number of averaged pixels is clamped to [1, rows * cols], so a tiny
 * image or a very small p no longer causes a division by zero and p > 1 no
 * longer reads past the end of the ranking. An empty bright channel yields an
 * all-zero result.
 *
 * @param I        image, read as 8-bit 3-channel (Vec3b); expected to have
 *                 the same size as brightch.
 * @param brightch bright channel, read as CV_32FC1.
 * @param p        fraction of the brightest pixels to average.
 * @return 3x1 CV_32FC1 matrix; row k holds the mean of colour channel k
 *         divided by 255.
 */
Mat get_atmosphere(Mat I, Mat brightch, float p=0.1) {
    const int N = brightch.rows;
    const int M = brightch.cols;

    // (negated brightness, flat index): ascending pair order == brightest first.
    std::vector<std::pair<float, int>> flatBright;
    flatBright.reserve(static_cast<std::size_t>(N) * static_cast<std::size_t>(M));
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            flatBright.emplace_back(-brightch.at<float>(j, i), i * N + j);
        }
    }

    Mat A = Mat::zeros(Size(1, 3), CV_32FC1);
    if (flatBright.empty()) {
        return A;
    }

    const int total = static_cast<int>(flatBright.size());
    const int count = std::max(1, std::min(static_cast<int>(M * N * p), total));

    // Only the first `count` entries are needed in sorted order.
    std::partial_sort(flatBright.begin(), flatBright.begin() + count, flatBright.end());

    // Accumulate in double: a float sum loses precision once it grows past
    // 2^24, which a multi-megapixel image reaches easily.
    double sums[3] = {0.0, 0.0, 0.0};
    for (int k = 0; k < count; k++) {
        const int index = flatBright[k].second;
        const Vec3b& px = I.at<Vec3b>(index % N, index / N);  // undo index = col * N + row
        sums[0] += px[0];
        sums[1] += px[1];
        sums[2] += px[2];
    }

    for (int c = 0; c < 3; c++) {
        A.at<float>(c, 0) = static_cast<float>(sums[c] / count / 255.0);
    }

    return A;
}

/**
 * Build the initial transmission map from the bright channel.
 *
 * Computes (brightch - max(A)) / (1 - max(A)) and then rescales the result
 * linearly so that its minimum maps to 0 and its maximum to 1.
 *
 * Degenerate inputs are guarded:
 *  - if max(A) is 1 (or larger) the first denominator is clamped to a small
 *    positive value; because of the min-max rescale that follows, the final
 *    map is unaffected by the exact value chosen;
 *  - if the map is constant (max == min) there is nothing to rescale and an
 *    all-ones map is returned instead of dividing by zero.
 *
 * @param A        atmospheric light; only its largest element is used.
 * @param brightch single-channel bright channel; not modified.
 * @return transmission map with the same size as brightch, values in [0, 1].
 */
Mat get_initial_transmission(Mat A, Mat brightch) {
    const double kMinDenominator = 1e-6;

    double A_x = 0.0;
    minMaxLoc(A, nullptr, &A_x);

    // Evaluating the expression creates a new matrix, so brightch is untouched.
    Mat init_t = (brightch - A_x) / std::max(1.0 - A_x, kMinDenominator);

    // Normalize initial transmission map
    double minVal = 0.0;
    double maxVal = 0.0;
    minMaxLoc(init_t, &minVal, &maxVal);

    const double range = maxVal - minVal;
    if (!(range > 0.0)) {  // also true when range is NaN
        return Mat::ones(brightch.size(), CV_32FC1);
    }

    init_t = (init_t - minVal) / range;
    return init_t;
}


/**
 * Compress the initial transmission map through a piecewise-linear curve.
 *
 * The map is quantised to 8 bits (value * 255, truncated and clamped to
 * [0, 255]), passed through a lookup table interpolated between the control
 * points (0 -> 0), (32 -> 32) and (255 -> 48), and written back divided by 255.
 *
 * The lookup table is CV_8UC1 and is now filled with at<uchar>; the previous
 * at<int> access wrote four bytes per entry and ran past the end of the
 * 256-byte buffer. Negative inputs are clamped to 0 instead of wrapping.
 *
 * @param init_t transmission map, read and written as CV_32FC1. The result is
 *               written into this matrix's own data, so other headers sharing
 *               that data observe the change.
 * @return init_t after the curve has been applied.
 */
Mat reduce_init_t(Mat init_t) {
    Mat mod_init_t(init_t.size(), CV_8UC1);

    // Quantise the float map to 8-bit values.
    for (int r = 0; r < init_t.rows; ++r) {
        const float* src = init_t.ptr<float>(r);
        uchar* dst = mod_init_t.ptr<uchar>(r);
        for (int c = 0; c < init_t.cols; ++c) {
            const int scaled = static_cast<int>(src[c] * 255);
            dst[c] = static_cast<uchar>(std::max(0, std::min(scaled, 255)));
        }
    }

    // Control points of the curve: input level x[k] maps to output level f[k].
    const int x[3] = {0, 32, 255};
    const int f[3] = {0, 32, 48};

    Mat table(Size(1, 256), CV_8UC1);

    // Linear interpolation of f over x for every 8-bit input level.
    int l = 0;
    for (int k = 0; k < 256; k++) {
        if (k > x[l+1]) {
            l = l + 1;
        }

        const float m = static_cast<float>(f[l+1] - f[l]) / (x[l+1] - x[l]);
        table.at<uchar>(k, 0) = static_cast<uchar>(static_cast<int>(f[l] + m*(k - x[l])));
    }

    // Lookup table
    LUT(mod_init_t, table, mod_init_t);

    // Back to float, scaled by 1/255.
    for (int r = 0; r < init_t.rows; ++r) {
        const uchar* src = mod_init_t.ptr<uchar>(r);
        float* dst = init_t.ptr<float>(r);
        for (int c = 0; c < init_t.cols; ++c) {
            dst[c] = static_cast<float>(src[c]) / 255;
        }
    }

    return init_t;
}

/**
 * Correct the initial transmission map using a dark-channel transmission map.
 *
 * Steps:
 *  1. divide every colour channel of I by the matching component of A;
 *  2. take the dark channel of that image and form dark_t = 1 - omega * dark;
 *  3. wherever (brightch - darkch) is below alpha, replace the transmission
 *     value by |dark_t * init_t|; elsewhere keep init_t.
 *
 * The result is a separate matrix: init_t is cloned rather than aliased, so
 * the caller's matrix is left untouched. Components of A are clamped to a
 * small positive value before the division to avoid inf/NaN.
 *
 * @param I        image, read as 8-bit 3-channel (Vec3b).
 * @param A        atmospheric light, read as float at rows 0..2, column 0.
 * @param darkch   dark channel of the image (CV_32FC1).
 * @param brightch bright channel of the image (CV_32FC1).
 * @param init_t   initial transmission map (CV_32FC1), not modified.
 * @param alpha    threshold on (brightch - darkch) below which a pixel is corrected.
 * @param omega    weight applied to the dark channel when forming dark_t.
 * @param w        window size forwarded to get_illumination_channel.
 * @return corrected transmission map (CV_32FC1).
 */
Mat get_corrected_transmission(Mat I, Mat A, Mat darkch, Mat brightch, Mat init_t, float alpha, float omega, int w) {
    const float kMinAtmosphere = 1e-6f;
    const float atmosphere[3] = {
        std::max(A.at<float>(0, 0), kMinAtmosphere),
        std::max(A.at<float>(1, 0), kMinAtmosphere),
        std::max(A.at<float>(2, 0), kMinAtmosphere),
    };

    // Image divided channel-wise by the atmospheric light.
    Mat im3(I.size(), CV_32FC3);
    for (int r = 0; r < I.rows; ++r) {
        const Vec3b* src = I.ptr<Vec3b>(r);
        Vec3f* dst = im3.ptr<Vec3f>(r);
        for (int c = 0; c < I.cols; ++c) {
            dst[c][0] = static_cast<float>(src[c][0]) / atmosphere[0];
            dst[c][1] = static_cast<float>(src[c][1]) / atmosphere[1];
            dst[c][2] = static_cast<float>(src[c][2]) / atmosphere[2];
        }
    }

    // Dark channel of the normalised image. im3 is CV_32FC3, so
    // get_illumination_channel has to read it as float data.
    std::pair<Mat, Mat> illuminate_channels = get_illumination_channel(im3, w);
    Mat dark_c = illuminate_channels.first;

    // Dark transmission map, used to correct the initial transmission map.
    Mat dark_t = 1 - omega*dark_c;

    // Start from a copy of the initial map: pixels whose bright/dark
    // difference is at least alpha keep their initial value.
    Mat corrected_t = init_t.clone();

    // Difference between the bright and dark channels.
    Mat diffch = brightch - darkch;

    for (int r = 0; r < diffch.rows; ++r) {
        const float* diffRow = diffch.ptr<float>(r);
        const float* darkRow = dark_t.ptr<float>(r);
        const float* initRow = init_t.ptr<float>(r);
        float* outRow = corrected_t.ptr<float>(r);

        for (int c = 0; c < diffch.cols; ++c) {
            // Where the difference is below alpha, the corrected value is the
            // magnitude of the product of the two transmission maps.
            if (diffRow[c] < alpha) {
                outRow[c] = std::abs(darkRow[c] * initRow[c]);
            }
        }
    }

    return corrected_t;
}

/**
 * Recover the enhanced image from the transmission map.
 *
 * For every pixel and colour channel c:
 *     J = (I - A[c]) / max(t, tmin) + A[c]
 * and the whole result is then rescaled linearly so that its smallest value
 * (over all channels) becomes 0 and its largest becomes 1.
 *
 * When every value of J is identical there is no range to stretch; the values
 * are then returned as computed instead of being divided by zero.
 *
 * @param I         image, read as CV_32FC3.
 * @param A         atmospheric light, read as float at rows 0..2, column 0.
 * @param refined_t transmission map (CV_32FC1); its size drives the pixel
 *                  loop, so it is expected to match the size of I.
 * @param tmin      lower bound applied to the transmission before dividing.
 * @return CV_32FC3 image with the size of I.
 */
Mat get_final_image(Mat I, Mat A, Mat refined_t, float tmin) {
    Mat J(I.size(), CV_32FC3);

    const float atmosphere[3] = {A.at<float>(0, 0), A.at<float>(1, 0), A.at<float>(2, 0)};

    for (int r = 0; r < refined_t.rows; ++r) {
        const float* tRow = refined_t.ptr<float>(r);
        const Vec3f* inRow = I.ptr<Vec3f>(r);
        Vec3f* outRow = J.ptr<Vec3f>(r);

        for (int c = 0; c < refined_t.cols; ++c) {
            // Transmission values below tmin are raised to tmin.
            const float t = std::max(tRow[c], tmin);

            for (int ch = 0; ch < 3; ++ch) {
                outRow[c][ch] = (inRow[c][ch] - atmosphere[ch]) / t + atmosphere[ch];
            }
        }
    }

    double minVal = 0.0;
    double maxVal = 0.0;
    minMaxLoc(J, &minVal, &maxVal);

    // Normalize the image J (skipped when the image is constant).
    const double range = maxVal - minVal;
    if (range > 0.0) {
        const int valuesPerRow = J.cols * 3;
        for (int r = 0; r < J.rows; ++r) {
            float* row = J.ptr<float>(r);
            for (int k = 0; k < valuesPerRow; ++k) {
                row[k] = static_cast<float>((row[k] - minVal) / range);
            }
        }
    }

    return J;
}

/**
 * Enhance an image with the bright/dark-channel pipeline defined in this file.
 *
 * Pipeline: illumination channels -> atmospheric light -> initial
 * transmission (optionally compressed by reduce_init_t) -> corrected
 * transmission -> guidedFilter refinement -> image recovery -> 8-bit
 * conversion -> detailEnhance and edgePreservingFilter.
 *
 * Fixes relative to the previous version: `p` is now actually forwarded to
 * get_atmosphere, and the guide image is normalised on all three channels
 * (subtracting a bare double from a 3-channel Mat only shifts channel 0).
 *
 * @param img    input image, read as 8-bit 3-channel (Vec3b).
 * @param tmin   lower bound on the transmission, forwarded to get_final_image.
 * @param w      window size for the illumination channels; also passed as the
 *               third argument of guidedFilter.
 * @param alpha  threshold forwarded to get_corrected_transmission.
 * @param omega  dark-channel weight forwarded to get_corrected_transmission.
 * @param p      fraction of brightest pixels forwarded to get_atmosphere.
 * @param eps    fourth argument of guidedFilter.
 * @param reduce when true, the initial transmission map is passed through
 *               reduce_init_t first.
 * @return the enhanced 8-bit image after detailEnhance and edgePreservingFilter.
 */
Mat dehaze(Mat img, float tmin=0.1, int w = 15, float alpha=0.4, float omega=0.75, float p=0.1, double eps=1e-3, bool reduce=false) {
    std::pair<Mat, Mat> illuminate_channels = get_illumination_channel(img, w);
    Mat Idark = illuminate_channels.first;
    Mat Ibright = illuminate_channels.second;

    Mat A = get_atmosphere(img, Ibright, p);

    Mat init_t = get_initial_transmission(A, Ibright);

    if (reduce) {
        init_t = reduce_init_t(init_t);
    }

    Mat corrected_t = get_corrected_transmission(img, A, Idark, Ibright, init_t, alpha, omega, w);

    // Float copy of the input, scaled by 1/255.
    Mat I(img.size(), CV_32FC3);
    for (int r = 0; r < img.rows; ++r) {
        const Vec3b* src = img.ptr<Vec3b>(r);
        Vec3f* dst = I.ptr<Vec3f>(r);
        for (int c = 0; c < img.cols; ++c) {
            dst[c][0] = static_cast<float>(src[c][0]) / 255;
            dst[c][1] = static_cast<float>(src[c][1]) / 255;
            dst[c][2] = static_cast<float>(src[c][2]) / 255;
        }
    }

    // Min-max normalised guide image. Scalar::all is required: a bare double
    // converts to Scalar(v, 0, 0, 0) and would shift the first channel only.
    double minVal = 0.0;
    double maxVal = 0.0;
    minMaxLoc(I, &minVal, &maxVal);
    const double range = maxVal - minVal;

    Mat normI;
    if (range > 0.0) {
        normI = (I - Scalar::all(minVal)) / range;
    } else {
        normI = I.clone();  // constant image: nothing to stretch
    }

    //Applying guided filter
    Mat refined_t = guidedFilter(normI, corrected_t, w, eps);

    Mat J_refined = get_final_image(I, A, refined_t, tmin);

    // Back to 8 bits: scale by 255, truncate, clamp to [0, 255].
    Mat enhanced(img.size(), CV_8UC3);
    for (int r = 0; r < img.rows; ++r) {
        const Vec3f* src = J_refined.ptr<Vec3f>(r);
        Vec3b* dst = enhanced.ptr<Vec3b>(r);
        for (int c = 0; c < img.cols; ++c) {
            for (int ch = 0; ch < 3; ++ch) {
                const int level = static_cast<int>(src[c][ch] * 255);
                dst[c][ch] = static_cast<uchar>(std::max(0, std::min(level, 255)));
            }
        }
    }

    Mat f_enhanced;
    detailEnhance(enhanced, f_enhanced, 10, 0.15);
    edgePreservingFilter(f_enhanced, f_enhanced, 1, 64, 0.2);

    return f_enhanced;
}

/**
 * Load an image, run dehaze on it twice (once with the default settings and
 * once with the transmission-map reduction enabled) and show the original and
 * both results in windows until a key is pressed.
 *
 * If the image cannot be read, a message is written to std::cerr and the
 * function returns without opening any window.
 *
 * @param path path of the image file handed to imread.
 */
void run(std::string path)  {
    Mat img = imread(path);
    if (img.empty()) {
        // imread reports failure by returning an empty matrix.
        std::cerr << "run: could not read image '" << path << "'\n";
        return;
    }

    Mat out_img = dehaze(img);
    Mat out_img2 = dehaze(img,0.1,15,0.4,0.75,0.1,1e-3,true);
    imshow("original",img);
    imshow("F_enhanced", out_img);
    imshow("F_enhanced2", out_img2);
    waitKey(0);
}